
Conversation

@penelopeysm
Member

@penelopeysm penelopeysm commented Oct 23, 2025

This PR adds a new struct ParamsWithStats, along with functions to_chains and from_chains, which are mainly meant for developers of packages that share an interface with DynamicPPL.

The main purpose of these functions is to abstract away the inner details of chain construction so that they don't have to be duplicated everywhere. For example, there are at least four different places that play the 'split-up-dicts-of-varnames' game for MCMCChains (a sketch of what these call sites could reduce to follows at the end of this comment):

(1) AbstractMCMC.bundle_samples https://github.com/TuringLang/Turing.jl/blob/0eb8576c2c1f659aafdc1a22fc6396e0b1588a67/src/mcmc/Inference.jl#L311-L312

(2) DynamicPPL.predict

function _predictive_samples_to_arrays(predictive_samples)
    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    sample_dicts = map(predictive_samples) do sample
        varname_value_pairs = sample.varname_and_values
        varnames = map(first, varname_value_pairs)
        values = map(last, varname_value_pairs)
        for varname in varnames
            push!(variable_names_set, varname)
        end
        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
    end
    variable_names = collect(variable_names_set)
    variable_values = [
        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
        key in variable_names
    ]
    return variable_names, variable_values
end

(3) This DynamicPPL test utility

function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
    # Sample from the prior
    varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
    # Extract all varnames found in any dictionary. Doing it this way guards
    # against the possibility of having different varnames in different
    # dictionaries, e.g. for models that have dynamic variables / array sizes
    varnames = OrderedSet{VarName}()
    # Convert each varinfo into an OrderedDict of vns => params.
    # We have to use varname_and_value_leaves so that each parameter is a scalar
    dicts = map(varinfos) do t
        vals = DynamicPPL.values_as(t, OrderedDict)
        iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
        tuples = mapreduce(collect, vcat, iters)
        # The following loop is a replacement for:
        #     push!(varnames, map(first, tuples)...)
        # which causes a stack overflow if `map(first, tuples)` is too large.
        # Unfortunately there isn't a union() function for OrderedSet.
        for vn in map(first, tuples)
            push!(varnames, vn)
        end
        OrderedDict(tuples)
    end
    # Convert back to list
    varnames = collect(varnames)
    # Construct matrix of values
    vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
    # Construct dict of varnames -> symbol
    vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
    # Construct and return the Chains object
    return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
end

(4) Pathfinder.pathfinder https://github.com/mlcolab/Pathfinder.jl/blob/6389f125197110ff35ccddc10ed682e4b9ff8c12/ext/PathfinderTuringExt.jl#L49

Another benefit is that certain details, like the varname_to_symbol Dict that is stored with the chain, are implemented at the same level at which they are used.


The eagle-eyed will notice that ParamsWithStats is effectively the same as Turing.Inference.Transition, just without the logp terms explicitly bundled in.

Furthermore, to_chains in the MCMCChainsExt is almost completely the same as bundle_samples in Turing (although perhaps implemented in a slightly simpler way).

I did it this way because I want Turing to be able to make use of this function. In an earlier draft, to_chains took an array of VarInfo and then performed the re-evaluation, but that made it quite complicated to use in the MCMC sampling parts of Turing.
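
For illustration, here is a rough sketch of what one of the duplicated call sites above could reduce to once this PR is in place. This is not the merged code: the exact ParamsWithStats constructor and the argument order of to_chains are assumptions based on the description in this comment.

using DynamicPPL, MCMCChains, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x)
end
model = demo()

# One VarInfo per iteration, e.g. sampled from the prior.
varinfos = [VarInfo(model) for _ in 1:10]

# Hypothetical replacement for the hand-rolled varname/value matrices above:
# wrap each VarInfo in a ParamsWithStats and let to_chains build the Chains
# object (including the varname_to_symbol info) internally.
pss = [DynamicPPL.ParamsWithStats(vi, model) for vi in varinfos]
chain = DynamicPPL.to_chains(MCMCChains.Chains, pss)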

@github-actions
Contributor

github-actions bot commented Oct 23, 2025

Benchmark Report for Commit 2fecc74

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.6 │             1.7 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          741.8 │            44.6 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          427.4 │            54.3 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          791.2 │            36.1 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         7067.5 │            25.5 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          760.5 │            54.9 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          760.3 │             6.0 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          919.2 │             3.8 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         3969.9 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1022.2 │             8.9 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        43894.2 │             5.5 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         9053.2 │             9.8 │
│               Dynamic │    10 │    mooncake │             typed │   true │          120.8 │            12.2 │
│              Submodel │     1 │    mooncake │             typed │   true │            8.7 │             6.7 │
│                   LDA │    12 │ reversediff │             typed │   true │         1021.3 │             2.0 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

@codecov
Copy link

codecov bot commented Oct 23, 2025

Codecov Report

❌ Patch coverage is 96.29630% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.38%. Comparing base (9a2607b) to head (2fecc74).
⚠️ Report is 1 commit behind head on main.

Files with missing lines         Patch %   Lines
src/to_chains.jl                 92.59%    2 Missing ⚠️
ext/DynamicPPLMCMCChainsExt.jl   98.14%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1087      +/-   ##
==========================================
+ Coverage   81.06%   81.38%   +0.32%     
==========================================
  Files          40       41       +1     
  Lines        3749     3798      +49     
==========================================
+ Hits         3039     3091      +52     
+ Misses        710      707       -3     

☔ View full report in Codecov by Sentry.

@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 96df1a4 to 70c3dd9 on October 23, 2025 18:36
@penelopeysm penelopeysm marked this pull request as ready for review on October 23, 2025 18:40
@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 70c3dd9 to 7049125 on October 23, 2025 18:47
@github-actions
Contributor

DynamicPPL.jl documentation for PR #1087 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1087/

@penelopeysm penelopeysm changed the title from "Add varinfos_to_chains function" to "Add to_chains function" on Oct 23, 2025
@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 7b337bd to 833cbbf on October 23, 2025 19:38
@penelopeysm
Member Author

CI failures are because of #1081.

@sethaxen Tagging you too since this came up via Pathfinder!

@penelopeysm penelopeysm requested a review from sunxd3 October 23, 2025 19:58
@sethaxen
Member

@sethaxen Tagging you too since this came up via Pathfinder!

Seems to work like a charm for Pathfinder! mlcolab/Pathfinder.jl#274

@sethaxen
Member

I wonder, for other packages implementing new chain_types, is there an equivalent to an inverse of to_chains that could be used e.g. for predict or pointwise_loglikelihoods? Or would something like from_chains, which given an object of type chain_type and a Model returns an array of ParamsWithStats, also be useful to have?

@penelopeysm
Member Author

penelopeysm commented Oct 24, 2025

an inverse of to_chains

So there are a couple of potential stages: Chain ----> Dict{VarName,Any} --[model]--> VarInfo.
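
For concreteness, here is a minimal sketch of the first stage (Chain --> Dict{VarName,Any}), assuming the chain carries a varname_to_symbol mapping in its info field, as described earlier in this thread. The helper name is hypothetical.

using MCMCChains, OrderedCollections
using DynamicPPL: VarName

# Hypothetical helper: pull out one iteration of a chain as VarName => value.
function chain_iteration_to_dict(chn::MCMCChains.Chains, iter::Int, chain_idx::Int=1)
    vn_to_sym = chn.info.varname_to_symbol  # assumes this mapping is stored with the chain
    d = OrderedDict{VarName,Any}()
    for (vn, sym) in pairs(vn_to_sym)
        d[vn] = chn[iter, sym, chain_idx]
    end
    return d
end

The second stage (Dict{VarName,Any} --[model]--> VarInfo) would then re-evaluate the model with these values, which is the part that needs the Model.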

@penelopeysm penelopeysm changed the title from "Add to_chains function" to "Add to_chains and from_chains function" on Oct 24, 2025
@sunxd3 sunxd3 left a comment
Member

Test errors look unrelated.

An alternative interface might be a constructor instead of from_chains (or maybe even to_chains). I understand the type-piracy headaches, though, so no objection to the current implementation.

    stats = merge(stats, (loglikelihood=DynamicPPL.getloglikelihood(varinfo),))
end
if has_prior_acc && has_likelihood_acc
    stats = merge(stats, (logjoint=DynamicPPL.getlogjoint(varinfo),))
Member

lp instead of logjoint?

Member Author

Ah, yes. I'm not sure what to do about this. I'd really like for it to be logjoint, because lp is a historical remnant from the time when we only had one kind of lp and the context determined which probability it was. Now we're quite specific that lp really means the log joint.

Forwarding this to MCMCChains would be a breaking change, so my original idea was to keep it as logjoint in DynamicPPL, and then inside bundle_samples for MCMCChains, we would override logjoint with lp to preserve the current behaviour (until we feel happy enough to break it).
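
A minimal sketch of that renaming step, purely as an illustration of the idea (not the merged code): keep logjoint internally and only rename it to lp when building the MCMCChains-facing stats.

# Rename :logjoint to :lp in a stats NamedTuple, leaving everything else intact.
function rename_logjoint_to_lp(stats::NamedTuple)
    haskey(stats, :logjoint) || return stats
    rest = Base.structdiff(stats, NamedTuple{(:logjoint,)})
    return merge(rest, (lp=stats.logjoint,))
end

rename_logjoint_to_lp((logprior=-1.2, loglikelihood=-3.4, logjoint=-4.6))
# -> (logprior = -1.2, loglikelihood = -3.4, lp = -4.6)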

Member

I like logjoint better than lp, so I'm supportive of this.

@sunxd3
Member

sunxd3 commented Oct 27, 2025

Question: not suggesting anything right now, but maybe we could start using these for TuringLang/Turing.jl#2651?

@penelopeysm
Member Author

Yes, I haven't tried it out, but I'm fairly certain Turing.Inference.Transition could be completely replaced with this. bundle_samples would become much shorter: basically, it would call to_chains and then tack on some extra info like sampling time (roughly as sketched below).
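
Rough sketch of that shape (the to_chains argument order and the use of MCMCChains.setinfo here are assumptions, not the actual Turing code):

using DynamicPPL, MCMCChains

# Hypothetical slimmed-down bundling step: let to_chains do the heavy lifting,
# then attach sampler metadata such as wall-clock timings.
function bundle_samples_sketch(pss::Vector{<:DynamicPPL.ParamsWithStats}, start_time, stop_time)
    chain = DynamicPPL.to_chains(MCMCChains.Chains, pss)  # argument order assumed
    extra = (start_time=start_time, stop_time=stop_time)
    return MCMCChains.setinfo(chain, merge(chain.info, extra))
end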

@penelopeysm penelopeysm merged commit 11b7e01 into main Oct 27, 2025
14 of 19 checks passed
@penelopeysm penelopeysm deleted the py/varinfos_to_chains branch October 27, 2025 11:43
@penelopeysm
Member Author

Argh, right after I clicked merge I realised that to_chains and from_chains should probably really be defined in AbstractMCMC (with the implementation for DPPL.ParamsWithStats here).

@penelopeysm
Member Author

I'll just hotfix all of that.

penelopeysm added a commit that referenced this pull request Oct 27, 2025
@penelopeysm
Member Author

penelopeysm commented Oct 27, 2025

I'm somewhat unsure about the best way forward. @sunxd3 / @sethaxen, thoughts very welcome. The ParamsWithStats struct would still live here, but in general there are a couple of options:

1. move the to_chains, from_chains functions to AbstractMCMC

The main reason I suggest this is that some of the conversion targets I implemented here just don't make sense to have in DPPL. For example, it seems odd to rely on a function DynamicPPL.from_chains to convert between MCMCChains and Array{NamedTuple}.

If we go down this route, then

  • the function should probably be defined in AbstractMCMC
  • conversion to/from NamedTuple should be in MCMCChains itself
  • conversion to/from DynamicPPL.ParamsWithStats (and maybe OrderedDict{<:VarName}) would be here

For users, nothing would change compared to what this PR did, except that you'd have to call AbstractMCMC.foo instead of DynamicPPL.foo.

2. restrict the functions to only take DynamicPPL.ParamsWithStats

The second option, which I somewhat prefer because it's just easier to keep track of, is to restrict the conversion input/output type to only Array{DynamicPPL.ParamsWithStats}, so no conversion to OrderedDict or NamedTuple.

In this case, because the function works only with a DynamicPPL struct, it can just live in DynamicPPL. Compared to this PR, I would just remove the NamedTuple and OrderedDict methods.

For users, if you wanted to get an OrderedDict{VarName} out of MCMCChains, you'd have to do it in two steps: first get the ps::DynamicPPL.ParamsWithStats, and then get the OrderedDict{VarName} as ps.params (as sketched below). This is marginally more complicated than before, but 7 additional characters aren't enough to make me feel bad about it.

You won't be able to get a NamedTuple out of MCMCChains, but for DynamicPPL purposes I think that is a rather dodgy proposition anyway, since it will silently go wrong down the line with non-identity-lens VarNames, so I don't mind omitting that functionality. It may be useful in other contexts, but then that can be an MCMCChains problem rather than a DynamicPPL one.
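
A sketch of that two-step workflow, reusing the hypothetical model and chain from the earlier sketch (the from_chains argument order is again an assumption):

# Recover per-iteration ParamsWithStats, then take .params for the dicts.
pss = DynamicPPL.from_chains(model, chain)   # hypothetical argument order
dicts = [ps.params for ps in pss]            # one OrderedDict{VarName} per iteration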

General point about constructors

There's a certain elegance to using constructors instead of to_chains, but the reason I somewhat prefer named functions is that they're easier to document and discover (using MCMCChains; methods(Chains) already shows 11 constructors, and a 12th would be a bit hard to find).

(Not trying to be argumentative, mostly documenting my rationale for future readers)
